{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Z-saUB9-PKrO"
   },
   "source": [
    "# TensorFlow2教程-TF fuction和AutoGraph\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WSVHzfadSCcU"
   },
   "source": [
    "在TensorFlow 2.0中，默认情况下启用了急切执行。 对于用户而言直观且灵活（运行一次性操作更容易，更快），但这可能会牺牲性能和可部署性。\n",
    "\n",
    "要获得最佳性能并使模型可在任何地方部署，请使用tf.function从程序中构建图。 因为有AutoGraph，可以使用tf.function构建高效性能的Python代码，但仍有一些陷阱需要警惕。\n",
    "\n",
    "主要的要点和建议是：\n",
    "\n",
    "不要依赖Python副作用，如对象变异或列表追加。\n",
    "tf.function最适合TensorFlow操作，而不是NumPy操作或Python原语。\n",
    "如有疑问，请使用for x in y idiom。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2519,
     "status": "ok",
     "timestamp": 1560061694048,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "0lo-ivSzO6wo",
    "outputId": "956c066f-b985-4c15-f0ed-1c79168530a0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0-beta0\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "#!pip uninstall tensorflow\n",
    "\n",
    "#!pip install tensorflow==2.0.0-beta0\n",
    "import tensorflow as tf\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rF5kXAQusuTS"
   },
   "source": [
    "下面的辅助程序代码，用于演示可能遇到的各种错误。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_2IK6UeITExX"
   },
   "outputs": [],
   "source": [
    "import contextlib\n",
    "\n",
    "# 构建包含上下文管理器的函数，使其可以在with中使用\n",
    "@contextlib.contextmanager\n",
    "def assert_raises(error_class):\n",
    "    try:\n",
    "        yield\n",
    "    except error_class as e:\n",
    "        print('Caught expected exception \\n  {}: {}'.format(error_class, e))\n",
    "    except Exception as e:\n",
    "        print('Got unexpected exception \\n  {}: {}'.format(type(e), e))\n",
    "    else:\n",
    "        raise Exception('Expected {} to be raised but no error was raised!'.format(\n",
    "            error_class))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-s2D0hxBvBKm"
   },
   "source": [
    "一个tf.function定义就像是一个核心TensorFlow操作：可以急切地执行它; 也可以在图表中使用它; 它有梯度; 等等。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 67
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2501,
     "status": "ok",
     "timestamp": 1560061694052,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "m0ctt0w6vHKk",
    "outputId": "4ca9475c-44a3-4710-b7bf-a0da0b6ae0dd"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=14, shape=(2, 2), dtype=float32, numpy=\n",
       "array([[2., 2.],\n",
       "       [2., 2.]], dtype=float32)>"
      ]
     },
     "execution_count": 3,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 类似一个tensorflow操作\n",
    "@tf.function\n",
    "def add(a, b):\n",
    "    return a+b\n",
    "\n",
    "add(tf.ones([2,2]), tf.ones([2,2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2488,
     "status": "ok",
     "timestamp": 1560061694053,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "tv6YXLJwvtZZ",
    "outputId": "8a93042f-1398-4dbe-c03a-2ec0e859b196"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=40, shape=(), dtype=float32, numpy=1.0>"
      ]
     },
     "execution_count": 4,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# tf.function操作可以计算梯度\n",
    "@tf.function\n",
    "def add(a, b):\n",
    "    return a+b\n",
    "v = tf.Variable(2.0)\n",
    "with tf.GradientTape() as tape:\n",
    "    res = add(v, 1.0)\n",
    "\n",
    "tape.gradient(res, v) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 84
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2782,
     "status": "ok",
     "timestamp": 1560061694357,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "0YWC-My4vWkz",
    "outputId": "44585da6-79eb-44df-8937-a11ddcba4e15"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=67, shape=(3, 2), dtype=float32, numpy=\n",
       "array([[3., 3.],\n",
       "       [3., 3.],\n",
       "       [3., 3.]], dtype=float32)>"
      ]
     },
     "execution_count": 5,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 可以内嵌调用tf.function\n",
    "@tf.function\n",
    "def dense_layer(x, w, b):\n",
    "    return add(tf.matmul(x, w), b)\n",
    "\n",
    "dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Fp-q1QwrzcWI"
   },
   "source": [
    "## 跟踪和多态\n",
    "Python的动态类型意味着您可以使用各种参数类型调用函数，Python将在每个场景中执行不同的操作。\n",
    "\n",
    "另一方面，TensorFlow图需要静态dtypes和形状尺寸。tf.function通过在必要时回溯函数来生成正确的图形来弥补这一差距。大多数使用的微妙tf.function源于这种回归行为。\n",
    "\n",
    "您可以使用不同类型的参数调用函数来查看正在发生的事情。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 168
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2771,
     "status": "ok",
     "timestamp": 1560061694358,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "83W6rdiTyUJG",
    "outputId": "ac5569a7-6f56-473a-b575-7313c060757f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "追踪变量： Tensor(\"a:0\", shape=(), dtype=int32)\n",
      "结果: tf.Tensor(2, shape=(), dtype=int32)\n",
      "\n",
      "追踪变量： Tensor(\"a:0\", shape=(), dtype=float32)\n",
      "结果: tf.Tensor(2.2, shape=(), dtype=float32)\n",
      "\n",
      "追踪变量： Tensor(\"a:0\", shape=(), dtype=string)\n",
      "结果: tf.Tensor(b'cc', shape=(), dtype=string)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 函数的多态\n",
    "@tf.function\n",
    "def double(a):\n",
    "    print('追踪变量：',a)\n",
    "    return a + a\n",
    "\n",
    "print('结果:',double(tf.constant(1)))\n",
    "print()\n",
    "print('结果:',double(tf.constant(1.1)))\n",
    "print()\n",
    "print('结果:',double(tf.constant('c')))\n",
    "print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UI5P2YJ_1kZ8"
   },
   "source": [
    "控制参数类型：\n",
    "创建一个新的tf.function。tf.function确保单独的对象不共享跟踪。\n",
    "使用该get_concrete_function方法获取特定追踪\n",
    "指定input_signature何时调用tf.function以确保仅构建一个功能图。\n"
   ]
  },
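  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first technique can be checked with a short sketch. It assumes only that a Python side effect inside the function (here, appending to a hypothetical list named traces) runs once per trace. Wrapping the same Python function in two separate tf.function objects should then record two traces, one per object, even though the inputs are identical. The names double_fn, f1 and f2 are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "traces = []  # hypothetical helper list: one entry per trace\n",
    "\n",
    "def double_fn(a):\n",
    "    traces.append('traced')  # Python side effect: runs only while tracing\n",
    "    return a + a\n",
    "\n",
    "f1 = tf.function(double_fn)\n",
    "f2 = tf.function(double_fn)\n",
    "\n",
    "f1(tf.constant(1))  # f1 traces for int32 scalars\n",
    "f1(tf.constant(2))  # same dtype and shape: f1 reuses its trace\n",
    "f2(tf.constant(1))  # f2 has its own cache, so it traces again\n",
    "print(len(traces))  # expected: 2"
   ]
  },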
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 171
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2763,
     "status": "ok",
     "timestamp": 1560061694359,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "YSm_G3eQ0H2y",
    "outputId": "b7c1a01e-105c-423b-c673-97f96f8554e2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "构建许可的追踪\n",
      "追踪变量： Tensor(\"a:0\", dtype=string)\n",
      "执行追踪函数\n",
      "tf.Tensor(b'aa', shape=(), dtype=string)\n",
      "tf.Tensor(b'bb', shape=(), dtype=string)\n",
      "使用不合法参数\n",
      "Caught expected exception \n",
      "  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute __inference_double_98 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_98]\n"
     ]
    }
   ],
   "source": [
    "print('构建许可的追踪')\n",
    "double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))\n",
    "print(\"执行追踪函数\")\n",
    "print(double_strings(tf.constant(\"a\")))\n",
    "print(double_strings(a=tf.constant(\"b\")))\n",
    "print(\"使用不合法参数\")\n",
    "with assert_raises(tf.errors.InvalidArgumentError):\n",
    "    double_strings(tf.constant(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 138
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3066,
     "status": "ok",
     "timestamp": 1560061694671,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "S1IzrEGm2foY",
    "outputId": "5a49a349-6b9e-469a-f10a-ab4287960fee"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n",
      "tf.Tensor([4 1], shape=(2,), dtype=int32)\n",
      "Caught expected exception \n",
      "  <class 'ValueError'>: Python inputs incompatible with input_signature: inputs ((<tf.Tensor: id=125, shape=(2, 2), dtype=int32, numpy=\n",
      "array([[1, 2],\n",
      "       [3, 4]], dtype=int32)>,)), input_signature ((TensorSpec(shape=(None,), dtype=tf.int32, name=None),))\n"
     ]
    }
   ],
   "source": [
    "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n",
    "def next_collatz(x):\n",
    "    print(\"Tracing with\", x)\n",
    "    return tf.where(tf.equal(x % 2, 0), x // 2, 3 * x + 1)\n",
    "\n",
    "print(next_collatz(tf.constant([1, 2])))\n",
    "# 只能输入1维向量\n",
    "with assert_raises(ValueError):\n",
    "    next_collatz(tf.constant([[1, 2], [3, 4]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-g6JfGdh4lDU"
   },
   "source": [
    "## 什么时候回溯？\n",
    "多态tf.function通过跟踪生成具体函数的缓存。缓存键实际上是从函数args和kwargs生成的键的元组。为tf.Tensor参数生成的关键是其形状和类型。为Python原语生成的密钥是它的值。对于所有其他Python类型，键都基于对象，id()以便为每个类的实例独立跟踪方法。将来，TensorFlow可以为Python对象添加更复杂的缓存，可以安全地转换为张量。"
   ]
  },
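  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cache-key rules above can be observed directly. This is a minimal sketch, assuming a hypothetical helper list named traces that receives one entry per trace (a Python side effect, so it only runs while tracing). Tensor arguments are keyed by dtype and shape, while Python ints are keyed by their value, so the six calls below should produce four traces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "traces = []  # hypothetical helper list: one entry per trace\n",
    "\n",
    "@tf.function\n",
    "def square(x):\n",
    "    traces.append('traced')  # Python side effect: runs only while tracing\n",
    "    return x * x\n",
    "\n",
    "square(tf.constant(2))    # new key (int32, shape ()): trace\n",
    "square(tf.constant(3))    # same dtype and shape: reuse\n",
    "square(tf.constant(2.0))  # new dtype float32: trace\n",
    "square(2)                 # Python int: the key is the value 2, so trace\n",
    "square(2)                 # same value: reuse\n",
    "square(3)                 # new value: trace\n",
    "print(len(traces))        # expected: 4"
   ]
  },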
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sBW4e19c50j-"
   },
   "source": [
    "## 使用Python参数还是Tensors参数？\n",
    "通常，Python的参数被用来控制超参数和图形的结构-例如，num_layers=10或training=True或nonlinearity='relu'。因此，如果Python参数发生变化，那么必须回溯图。\n",
    "\n",
    "但是，Python参数可能不会用于控制图构造。在这些情况下，Python值的变化可能会触发不必要的回溯。举例来说，这个训练循环，AutoGraph将动态展开。尽管存在多条迹线，但生成的图实际上是相同的，因此这有点低效。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 50
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3054,
     "status": "ok",
     "timestamp": 1560061694671,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "2Vy9Zrkr4LBi",
    "outputId": "ed6cc253-339c-4c15-d57a-6f456ca8da56"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "追踪： num_steps = 10\n",
      "追踪： num_steps = 20\n"
     ]
    }
   ],
   "source": [
    "def train_one_step():\n",
    "    pass\n",
    "\n",
    "@tf.function\n",
    "def train(num_steps):\n",
    "    print(\"追踪： num_steps = {}\".format(num_steps))\n",
    "    for _ in tf.range(num_steps):\n",
    "        train_one_step()\n",
    "\n",
    "train(num_steps=10)\n",
    "train(num_steps=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3374,
     "status": "ok",
     "timestamp": 1560061694999,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "yZIGmt776ORD",
    "outputId": "453e97ff-1d67-45fa-f397-d5535fba0da2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "追踪： num_steps = Tensor(\"num_steps:0\", shape=(), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "# 使用tensor，同类型不会重复追踪\n",
    "train(num_steps=tf.constant(10))\n",
    "train(num_steps=tf.constant(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3368,
     "status": "ok",
     "timestamp": 1560061695000,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "Bq1MGQQI6ZEk",
    "outputId": "a9a618d9-8203-4e48-dba9-480717f34aa1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "追踪： num_steps = Tensor(\"num_steps:0\", shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# 使用tensor，类型不同才会有新的追踪，（前一个单元格已追踪int型，所以该处不追踪）\n",
    "train(num_steps=tf.constant(10, dtype=tf.int32))\n",
    "train(num_steps=tf.constant(20.6))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "o5KQqFTK79aP"
   },
   "source": [
    "## 副作用 tf.function\n",
    "通常，Python副作用（如打印或变异对象）仅在跟踪期间发生。你怎么能可靠地触发副作用tf.function呢？\n",
    "\n",
    "一般的经验法则是仅使用Python副作用来调试跟踪。但是，TensorFlow操作类似于tf.Variable.assign，tf.print并且tf.summary是确保TensorFlow运行时在每次调用时跟踪和执行代码的最佳方法。通常使用功能样式将产生最佳结果。\n",
    "\n",
    "tf.function函数中的print()被用于跟踪，所以要调试输出每次调用(副作用),就需要tf.function()"
   ]
  },
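  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between Python state and TensorFlow state can be made concrete with counters. In this sketch (the names py_calls, tf_calls and step are illustrative), a Python list is mutated only while the function is traced, whereas a tf.Variable updated with assign_add is updated by the runtime on every call. Three calls with same-typed tensors should give one trace and three updates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "py_calls = []              # Python state: mutated only while tracing\n",
    "tf_calls = tf.Variable(0)  # TensorFlow state: updated by the runtime on every call\n",
    "\n",
    "@tf.function\n",
    "def step(x):\n",
    "    py_calls.append(1)      # Python side effect\n",
    "    tf_calls.assign_add(1)  # TensorFlow side effect\n",
    "    return x + 1\n",
    "\n",
    "for i in range(3):\n",
    "    step(tf.constant(i))  # same dtype and shape each time: a single trace\n",
    "\n",
    "print(len(py_calls), int(tf_calls.numpy()))  # expected: 1 3"
   ]
  },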
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nNE9tUEU6fzd"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def f(x):\n",
    "    print(\"追踪：\", x)\n",
    "    tf.print('执行：', x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 101
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 730,
     "status": "ok",
     "timestamp": 1560062207799,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "AIbz4shg8-TM",
    "outputId": "1e0f69aa-58c1-4160-b341-847959f204a7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "追踪： 1\n",
      "执行： 1\n",
      "执行： 1\n",
      "追踪： 2\n",
      "执行： 2\n"
     ]
    }
   ],
   "source": [
    "f(1)\n",
    "f(1)\n",
    "f(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lhGssQgi9dZ8"
   },
   "source": [
    "如果想在每次调用期间执行Python代码tf.function，可以使用tf.py_function。tf.py_function缺点是它不便携和高效，也不能在分布式（多GPU，TPU）设置中很好地工作。此外，由于tf.py_function必须连接到图，它将所有输入/输出转换为张量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 171
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 732,
     "status": "ok",
     "timestamp": 1560062465050,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "E64cIU0H9CQu",
    "outputId": "7d5414cc-807e-4fc4-d84d-684af63fb851"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0609 06:41:05.048375 139792217777920 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32\n",
      "W0609 06:41:05.053524 139792217777920 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32\n",
      "W0609 06:41:05.056409 139792226170624 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Python side effect\n",
      "Python side effect\n",
      "Python side effect\n",
      "[<tf.Tensor: id=351, shape=(), dtype=int32, numpy=1>, <tf.Tensor: id=352, shape=(), dtype=int32, numpy=1>, <tf.Tensor: id=353, shape=(), dtype=int32, numpy=1>]\n"
     ]
    }
   ],
   "source": [
    "external_list = []\n",
    "\n",
    "def side_effect(x):\n",
    "    print('Python side effect')\n",
    "    external_list.append(x)\n",
    "\n",
    "@tf.function\n",
    "def f(x):\n",
    "    tf.py_function(side_effect, inp=[x], Tout=[])\n",
    "\n",
    "f(1)\n",
    "f(1)\n",
    "f(1)\n",
    "print(external_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qEvUyInkCWJN"
   },
   "source": [
    "## 谨防Python状态\n",
    "许多Python功能（如生成器和迭代器）依赖于Python运行时来跟踪状态。 通常，虽然这些构造在Eager模式下按预期工作，但由于跟踪行为，tf.function内部可能会发生许多意外情况。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 67
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 904,
     "status": "ok",
     "timestamp": 1560063871419,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "APe8jOwi-BES",
    "outputId": "4ed69333-d8b7-4e8e-c23d-155e47fd811d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "external_var: 0\n",
      "external_var: 0\n",
      "external_var: 0\n"
     ]
    }
   ],
   "source": [
    "external_var = tf.Variable(0)\n",
    "@tf.function\n",
    "def buggy_consume_next(iterator):\n",
    "    external_var.assign_add(next(iterator))\n",
    "    tf.print('external_var:', external_var)\n",
    "    \n",
    "iterator = iter([0,1,2,3])\n",
    "buggy_consume_next(iterator)\n",
    "# 后面没有正常迭代，输出的都是第一个\n",
    "buggy_consume_next(iterator)\n",
    "buggy_consume_next(iterator)"
   ]
  },
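  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of one way to avoid this, assuming a more recent TensorFlow 2.x release than the 2.0 beta used in this notebook (one in which iterators over a tf.data.Dataset can be passed into a tf.function): when the iterator comes from a tf.data.Dataset, next() becomes a graph op that runs on every call, instead of a Python call that runs only during tracing. The names counter_var, good_consume_next and ds_iterator are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "counter_var = tf.Variable(0)  # illustrative name, kept separate from external_var above\n",
    "\n",
    "@tf.function\n",
    "def good_consume_next(iterator):\n",
    "    # iterator is a tf.data iterator, so next() is a graph op executed on every call\n",
    "    counter_var.assign_add(next(iterator))\n",
    "    tf.print('counter_var:', counter_var)\n",
    "\n",
    "ds_iterator = iter(tf.data.Dataset.from_tensor_slices([0, 1, 2, 3]))\n",
    "good_consume_next(ds_iterator)  # adds 0\n",
    "good_consume_next(ds_iterator)  # adds 1\n",
    "good_consume_next(ds_iterator)  # adds 2\n",
    "print(int(counter_var.numpy()))  # expected: 3"
   ]
  },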
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "eBOKO-DZDlcU"
   },
   "source": [
    "如果在tf.function中生成并完全使用了迭代器，那么它应该可以正常工作。但是，整个迭代器可能正在被跟踪，这可能导致一个巨大的图。如果正在训练一个表示为Python列表的大型内存数据集，那么这会生成一个非常大的图，并且tf.function不太可能产生加速。\n",
    "\n",
    "如果要迭代Python数据，最安全的方法是将其包装在tf.data.Dataset中并使用该for x in y惯用法。AutoGraph特别支持for在y张量或tf.data.Dataset 时安全地转换循环。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 104
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 954,
     "status": "ok",
     "timestamp": 1560064507338,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "A5yosSxIDMMj",
    "outputId": "39b311d3-c6b6-4139-9b61-ccb3c803e107"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train([(1, 1), (1, 1)]) 的图中包含了 8 个节点\n",
      "train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) 的图中包含了 32 个节点\n",
      "train(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) 的图中包含了 4 个节点\n",
      "train(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) 的图中包含了 4 个节点\n"
     ]
    }
   ],
   "source": [
    "def measure_graph_size(f, *args):\n",
    "    g = f.get_concrete_function(*args).graph\n",
    "    print(\"{}({}) 的图中包含了 {} 个节点\".format(\n",
    "      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))\n",
    "\n",
    "@tf.function\n",
    "def train(dataset):\n",
    "    loss = tf.constant(0)\n",
    "    for x, y in dataset:\n",
    "        loss += tf.abs(y - x) # Some dummy computation.\n",
    "    return loss\n",
    "\n",
    "small_data = [(1, 1)] * 2\n",
    "big_data = [(1, 1)] * 10\n",
    "measure_graph_size(train, small_data)\n",
    "measure_graph_size(train, big_data)\n",
    "\n",
    "measure_graph_size(train, tf.data.Dataset.from_generator(\n",
    "    lambda: small_data, (tf.int32, tf.int32)))\n",
    "measure_graph_size(train, tf.data.Dataset.from_generator(\n",
    "    lambda: big_data, (tf.int32, tf.int32)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "u_c0lfVhF-MU"
   },
   "source": [
    "在数据集中包装Python / Numpy数据时，请注意tf.data.Dataset.from_generator与tf.data.Dataset.from_tensors。前者将数据保存在Python中并通过tf.py_function它获取性能影响，而后者将数据的副本捆绑为图中的一个大tf.constant()节点，这可能会对内存产生影响。\n",
    "\n",
    "通过TFRecordDataset / CsvDataset / etc从文件中读取数据。是最有效的数据处理方式，因为TensorFlow本身可以管理数据的异步加载和预取，而不必涉及Python。"
   ]
  },
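  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small sketch contrasting the two ways of wrapping in-memory data. It uses tf.data.Dataset.from_tensor_slices, the element-wise sibling of from_tensors (from_tensors would yield the whole array as a single element), and the NumPy array data is a hypothetical stand-in for a real dataset. Both pipelines should yield the same elements; they differ only in where the data lives while the pipeline runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "data = np.arange(6, dtype=np.int32)  # hypothetical small in-memory dataset\n",
    "\n",
    "# from_generator: the data stays in Python and is fetched through a Python callback at run time\n",
    "gen_ds = tf.data.Dataset.from_generator(lambda: data, output_types=tf.int32)\n",
    "\n",
    "# from_tensor_slices: a copy of the array is embedded as a constant and sliced per element\n",
    "const_ds = tf.data.Dataset.from_tensor_slices(data)\n",
    "\n",
    "gen_values = [int(x) for x in gen_ds]\n",
    "const_values = [int(x) for x in const_ds]\n",
    "print(gen_values == const_values)  # expected: True"
   ]
  },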
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TEMIwcR9GZNS"
   },
   "source": [
    "## 自动控制依赖项\n",
    "在一般数据流图上，作为编程模型的函数的一个非常吸引人的特性是函数可以为运行时提供有关代码的预期行为的更多信息。\n",
    "\n",
    "例如，当编写具有多个读取和写入相同变量的代码时，数据流图可能不会自然地编码最初预期的操作顺序。在tf.function，我们通过引用原始Python代码中的语句的执行顺序来解决执行顺序中的歧义。这样，有序状态操作的排序tf.function复制了Eager模式的语义。\n",
    "\n",
    "这意味着不需要添加手动控制依赖项; tf.function足够聪明，可以为代码添加最小的必要和充分的控制依赖关系，以便正确运行。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 817,
     "status": "ok",
     "timestamp": 1560064847261,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "1hZ5UACIFOwG",
    "outputId": "f0ed92ab-024a-4c06-b9ea-53a280adf9cd"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=739, shape=(), dtype=float32, numpy=10.0>"
      ]
     },
     "execution_count": 24,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 按顺序自动执行\n",
    "a = tf.Variable(1.0)\n",
    "b = tf.Variable(2.0)\n",
    "\n",
    "@tf.function\n",
    "def f(x, y):\n",
    "    a.assign(y * b)\n",
    "    b.assign_add(x * a)\n",
    "    return a + b\n",
    "\n",
    "f(1.0, 2.0)"
   ]
  },
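  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check that the ordering matches eager semantics, the sketch below runs the same update once eagerly and once as a tf.function, each time on fresh variables. The helper name update is illustrative, and the sketch relies on tf.Variable objects being accepted as tf.function arguments, which TensorFlow 2.x supports. With a = 1.0, b = 2.0, x = 1.0 and y = 2.0, a becomes 4.0, b becomes 6.0, and both runs should return 10.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "def update(a, b, x, y):\n",
    "    a.assign(y * b)      # writes a (reads b)\n",
    "    b.assign_add(x * a)  # must read the new value of a, then writes b\n",
    "    return a + b\n",
    "\n",
    "graph_update = tf.function(update)\n",
    "\n",
    "# Fresh variables for each run so the two results are comparable\n",
    "eager_result = update(tf.Variable(1.0), tf.Variable(2.0), 1.0, 2.0)\n",
    "graph_result = graph_update(tf.Variable(1.0), tf.Variable(2.0), 1.0, 2.0)\n",
    "print(float(eager_result), float(graph_result))  # expected: 10.0 10.0"
   ]
  },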
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sCyejiCzHg9s"
   },
   "source": [
    "变量\n",
    "我们可以使用相同的想法来利用代码的预期执行顺序，使变量创建和利用变得非常容易tf.function。但是有一个非常重要的警告，即使用变量，可以编写在急切模式和图形模式下表现不同的代码。\n",
    "\n",
    "具体来说，每次调用创建一个新变量时都会发生这种情况。由于跟踪语义，tf.function每次调用都会重用相同的变量，但是eager模式会在每次调用时创建一个新变量。为防止出现此错误，tf.function如果检测到危险变量创建行为，则会引发错误。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 306
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1011,
     "status": "ok",
     "timestamp": 1560065070996,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "MvEHlVfOHsY2",
    "outputId": "78e52b1a-9160-4dba-f12a-27117d314cf7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Caught expected exception \n",
      "  <class 'ValueError'>: in converted code:\n",
      "\n",
      "    <ipython-input-25-8e74447e7577>:4 f  *\n",
      "        v = tf.Variable(1.0)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:262 __call__\n",
      "        return cls._variable_v2_call(*args, **kwargs)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call\n",
      "        shape=shape)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:60 getter\n",
      "        return captured_getter(captured_previous, **kwargs)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:364 invalid_creator_scope\n",
      "        \"tf.function-decorated function tried to create \"\n",
      "\n",
      "    ValueError: tf.function-decorated function tried to create variables on non-first call.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def f(x):\n",
    "    # tf.function会重复调用相同变量，而eager每次都会创建新的变量\n",
    "    v = tf.Variable(1.0)\n",
    "    v.assign_add(x)\n",
    "    return v\n",
    "\n",
    "with assert_raises(ValueError):\n",
    "    f(1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UB6kOLf2Hqim"
   },
   "source": [
    "不会报错的方法是"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 50
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 827,
     "status": "ok",
     "timestamp": 1560065172070,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "-HPVyn4nINGe",
    "outputId": "16c68f4c-11fb-4829-e3fa-ff4e539af814"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(2.0, shape=(), dtype=float32)\n",
      "tf.Tensor(4.0, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "v = tf.Variable(1.0)  # 把变量拿到tf.function外面\n",
    "\n",
    "@tf.function\n",
    "def f(x):\n",
    "    return v.assign_add(x)\n",
    "\n",
    "print(f(1.0))  # 2.0\n",
    "print(f(2.0))  # 4.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "X8lVcCU4IJk7"
   },
   "source": [
    "也可以在tf.function中创建变量，只要可以保证这些变量仅在第一次执行函数时创建。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 50
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 839,
     "status": "ok",
     "timestamp": 1560065271907,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "Os5GPslSG6_3",
    "outputId": "56d6f91d-4a8d-4712-e0a6-8a54bc5e3372"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(2.0, shape=(), dtype=float32)\n",
      "tf.Tensor(4.0, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "class C: pass\n",
    "obj = C(); obj.v = None\n",
    "\n",
    "@tf.function\n",
    "def g(x):\n",
    "    if obj.v is None:\n",
    "        obj.v = tf.Variable(1.0)\n",
    "    return obj.v.assign_add(x)\n",
    "\n",
    "print(g(1.0))  # 2.0\n",
    "print(g(2.0))  # 4.0"
   ]
  },
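  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same create-once pattern is commonly packaged in a class. The sketch below uses tf.Module with a variable created lazily on the first call; the class name Accumulator is hypothetical. It relies on the same rule as the cell above: the Python check self.total is None is only true during the first trace, so the variable is never created on a later call. Tensor inputs are used so that the second call reuses the existing trace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "class Accumulator(tf.Module):  # hypothetical example class\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.total = None  # the variable is created lazily on the first call\n",
    "\n",
    "    @tf.function\n",
    "    def __call__(self, x):\n",
    "        if self.total is None:  # Python check: only true during the first trace\n",
    "            self.total = tf.Variable(0.0)\n",
    "        return self.total.assign_add(x)\n",
    "\n",
    "acc = Accumulator()\n",
    "r1 = float(acc(tf.constant(1.0)))\n",
    "r2 = float(acc(tf.constant(2.0)))\n",
    "print(r1, r2)  # expected: 1.0 3.0"
   ]
  },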
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "enPov9K2IvkW"
   },
   "source": [
    "变量初始值设定项可以依赖于函数参数和其他变量的值。 我们可以使用与生成控制依赖关系相同的方法找出正确的初始化顺序。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 50
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1182,
     "status": "ok",
     "timestamp": 1560065512639,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "NlpF8mUYJl86",
    "outputId": "3caac82d-395f-4d42-a28e-b5057b40d66f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(12.0, shape=(), dtype=float32)\n",
      "tf.Tensor(36.0, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "state = []\n",
    "@tf.function\n",
    "def fn(x):\n",
    "    if not state:\n",
    "        state.append(tf.Variable(2.0 * x))\n",
    "        state.append(tf.Variable(state[0] * 3.0))\n",
    "    return state[0] * x * state[1]\n",
    "\n",
    "print(fn(tf.constant(1.0)))\n",
    "print(fn(tf.constant(3.0)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "t6i29GwEJlAk"
   },
   "source": [
    "## 使用AutoGraph\n",
    "该签名库完全集成tf.function，它将改写条件和循环依赖于张量在图形动态运行。\n",
    "\n",
    "tf.cond并且tf.while_loop继续使用tf.function，但是当以命令式样式编写时，具有控制流的代码通常更容易编写和理解。"
   ]
  },
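  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conditionals are converted in the same way as loops. In the sketch below (the function name keep_if_positive_sum is illustrative), the if statement depends on a tensor, so AutoGraph rewrites it as a graph conditional rather than a Python branch. tf.autograph.to_code can be used to inspect the converted source, although its exact output differs between TensorFlow versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "def keep_if_positive_sum(x):\n",
    "    # The condition is a tensor, so AutoGraph converts this if/else into a graph conditional\n",
    "    if tf.reduce_sum(x) > 0:\n",
    "        y = x\n",
    "    else:\n",
    "        y = tf.zeros_like(x)\n",
    "    return y\n",
    "\n",
    "graph_fn = tf.function(keep_if_positive_sum)\n",
    "pos = graph_fn(tf.constant([1.0, 2.0]))\n",
    "neg = graph_fn(tf.constant([-1.0, -2.0]))\n",
    "print(pos.numpy(), neg.numpy())  # expected: [1. 2.] [0. 0.]\n",
    "\n",
    "# Print the Python code that AutoGraph generates for the function\n",
    "print(tf.autograph.to_code(keep_if_positive_sum))"
   ]
  },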
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 655
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 938,
     "status": "ok",
     "timestamp": 1560065826332,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "lKn39wWTKMQB",
    "outputId": "76e9991c-90f8-4d43-ca9b-8a8a74ca9960"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.829342961 0.858322263 0.900950909 0.851897 0.530384183]\n",
      "[0.680123031 0.695392191 0.716760576 0.692059278 0.485674709]\n",
      "[0.591599405 0.601434886 0.614898741 0.599303305 0.450776756]\n",
      "[0.53104496 0.538069844 0.547566235 0.536553681 0.422537297]\n",
      "[0.486179501 0.491525501 0.498693913 0.490374774 0.399065822]\n",
      "[0.451178908 0.455426365 0.461089343 0.454513818 0.379149348]\n",
      "[0.422867566 0.426349223 0.430971652 0.425602287 0.361968517]\n",
      "[0.399343461 0.402265817 0.406133026 0.401639521 0.346946776]\n",
      "[0.379387051 0.381885976 0.385184318 0.381350905 0.333665]\n",
      "[0.362175018 0.36434418 0.367201209 0.363880038 0.321810097]\n",
      "[0.347128421 0.349034756 0.351541221 0.348627061 0.311142713]\n",
      "[0.333826423 0.335519224 0.337741673 0.335157365 0.30147627]\n",
      "[0.321954757 0.323471278 0.325459719 0.323147237 0.292663]\n",
      "[0.311273336 0.312642276 0.314435244 0.312349856 0.284584]\n",
      "[0.301595032 0.302838922 0.304466605 0.302573323 0.277142316]\n",
      "[0.292771578 0.293908447 0.295394808 0.293665737 0.270258158]\n",
      "[0.284683794 0.285728157 0.287092626 0.285505235 0.263865024]\n",
      "[0.277234435 0.278198302 0.279456645 0.277992576 0.257907033]\n",
      "[0.270343572 0.271236718 0.272402078 0.271046132 0.25233686]\n",
      "[0.263944477 0.264775217 0.265858531 0.264597982 0.247114092]\n",
      "[0.257981181 0.258756459 0.259766966 0.258591145 0.242203966]\n",
      "[0.252406299 0.253132015 0.254077554 0.252977312 0.237576365]\n",
      "[0.24717927 0.247860536 0.248747766 0.247715324 0.233205199]\n",
      "[0.242265314 0.242906466 0.24374117 0.242769822 0.229067564]\n",
      "[0.237634286 0.238239139 0.239026278 0.238110229 0.225143358]\n",
      "[0.233259991 0.233831868 0.234575793 0.233709976 0.221414775]\n",
      "[0.229119495 0.229661271 0.230365857 0.229545817 0.217866093]\n",
      "[0.225192651 0.22570689 0.22637549 0.225597292 0.214483246]\n",
      "[0.221461684 0.221950635 0.222586185 0.221846417 0.211253688]\n",
      "[0.217910782 0.218376443 0.218981609 0.218277216 0.208166167]\n",
      "[0.214525893 0.214970052 0.215547174 0.214875415 0.205210552]\n",
      "[0.211294428 0.211718708 0.212269917 0.211628318 0.202377662]\n",
      "[0.208205134 0.208611 0.209138155 0.20852454 0.199659243]\n",
      "[0.205247864 0.205636591 0.206141427 0.2055538 0.197047815]\n",
      "[0.20241344 0.202786222 0.203270242 0.202706844 0.194536477]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=1006, shape=(5,), dtype=float32, numpy=\n",
       "array([0.19969359, 0.2000515 , 0.2005161 , 0.19997531, 0.192119  ],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 30,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A simple loop\n",
    "@tf.function\n",
    "def f(x):\n",
    "    # Write the loop directly with a Python while statement\n",
    "    while tf.reduce_sum(x) > 1:\n",
    "        tf.print(x)\n",
    "        x = tf.tanh(x)\n",
    "    return x\n",
    "f(tf.random.uniform([5]))"
   ]
  },
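  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above carries a single loop variable, `x`. The cell below is a minimal sketch that carries a second tensor, a step counter, through the same kind of loop. The function name `count_tanh_steps` is hypothetical and not part of the original tutorial. Both loop variables are tensors initialized before the loop, so AutoGraph should convert the `while` statement into one `tf.while_loop`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "@tf.function\n",
    "def count_tanh_steps(x):\n",
    "    # Both loop variables are tensors defined before the loop,\n",
    "    # so the while statement is converted to tf.while_loop.\n",
    "    steps = tf.constant(0)\n",
    "    while tf.reduce_sum(x) > 1:\n",
    "        x = tf.tanh(x)\n",
    "        steps += 1\n",
    "    return x, steps\n",
    "\n",
    "x_final, n_steps = count_tanh_steps(tf.constant([0.9, 0.9, 0.9]))\n",
    "print(x_final.numpy(), n_steps.numpy())"
   ]
  },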
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1101,
     "status": "ok",
     "timestamp": 1560065918037,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "Pf9vcB5qIuTj",
    "outputId": "3b0c472c-ce7b-4d18-b6c2-0721e33264ab"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<tensorflow.python.eager.def_function.Function object at 0x7f23e7df2240>\n"
     ]
    }
   ],
   "source": [
    "print(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "x6m7Js_-LXqM"
   },
   "source": [
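    "You can inspect the code that AutoGraph generates with `tf.autograph.to_code`, although reading it feels a bit like reading assembly language."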
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 558
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1260,
     "status": "ok",
     "timestamp": 1560065998816,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "MTgjr-RHLcy8",
    "outputId": "9cc417bc-e624-4528-96d2-be1d78a3f6b2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def tf__f(x):\n",
      "  do_return = False\n",
      "  retval_ = ag__.UndefinedReturnValue()\n",
      "\n",
      "  def loop_test(x_1):\n",
      "    return ag__.converted_call('reduce_sum', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None) > 1\n",
      "\n",
      "  def loop_body(x_1):\n",
      "    ag__.converted_call('print', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None)\n",
      "    x_1 = ag__.converted_call('tanh', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None)\n",
      "    return x_1,\n",
      "  x, = ag__.while_stmt(loop_test, loop_body, (x,))\n",
      "  do_return = True\n",
      "  retval_ = x\n",
      "  cond = ag__.is_undefined_return(retval_)\n",
      "\n",
      "  def get_state():\n",
      "    return ()\n",
      "\n",
      "  def set_state(_):\n",
      "    pass\n",
      "\n",
      "  def if_true():\n",
      "    retval_ = None\n",
      "    return retval_\n",
      "\n",
      "  def if_false():\n",
      "    return retval_\n",
      "  retval_ = ag__.if_stmt(cond, if_true, if_false, get_state, set_state)\n",
      "  return retval_\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def f(x):\n",
    "    while tf.reduce_sum(x) > 1:\n",
    "        tf.print(x)\n",
    "        x = tf.tanh(x)\n",
    "    return x\n",
    "\n",
    "print(tf.autograph.to_code(f))"
   ]
  },
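  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tf.autograph.to_code` accepts any plain Python function, so you can also use it to see how other statements are rewritten. The sketch below uses a hypothetical function, `sum_up_to`. In it, a `for` loop should become a call to `ag__.for_stmt`, in the same way that the `while` loop above became `ag__.while_stmt`. The exact generated text differs between TensorFlow versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "def sum_up_to(n):\n",
    "    # Plain Python function: to_code converts it without running it.\n",
    "    total = tf.constant(0)\n",
    "    for i in tf.range(n):\n",
    "        total += i\n",
    "    return total\n",
    "\n",
    "generated = tf.autograph.to_code(sum_up_to)\n",
    "print(generated)"
   ]
  },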
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GNaylIavLzFd"
   },
   "source": [
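    "## AutoGraph: conditionals\n",
    "\n",
    "AutoGraph converts `if` statements into equivalent `tf.cond` calls.\n",
    "\n",
    "This substitution is made only if the condition is a tensor. Otherwise, the conditional is executed during tracing."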
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "FbkERy0-L4JI"
   },
   "outputs": [],
   "source": [
    "# Helper: report whether the traced graph of f contains a tf.cond node\n",
    "def test_tf_cond(f, *args):\n",
    "    # Trace f with the given arguments and get its graph\n",
    "    g = f.get_concrete_function(*args).graph\n",
    "    if any(node.name == 'cond' for node in g.as_graph_def().node):\n",
    "        print(\"{}({}) uses tf.cond.\".format(\n",
    "            f.__name__, ', '.join(map(str, args))))\n",
    "    else:\n",
    "        print(\"{}({}) executes normally.\".format(\n",
    "            f.__name__, ', '.join(map(str, args))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PtVaFUIfM-Zq"
   },
   "source": [
    "`tf.cond` is used only when the condition is a tensor. In `hyperparam_cond`, `training` is a Python bool, so the `if` is evaluated once during tracing. Likewise, `maybe_tensor_cond(-1)` receives a Python int, so no `tf.cond` is created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 67
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1043,
     "status": "ok",
     "timestamp": 1560066497194,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "8uJQ8aVENPNs",
    "outputId": "57df931f-0a41-41c1-e6dc-cd32cc6a40fd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hyperparam_cond(tf.Tensor([1.], shape=(1,), dtype=float32)) executes normally.\n",
      "maybe_tensor_cond(tf.Tensor(-1, shape=(), dtype=int32)) uses tf.cond.\n",
      "maybe_tensor_cond(-1) executes normally.\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def hyperparam_cond(x, training=True):\n",
    "    if training:\n",
    "        x = tf.nn.dropout(x, rate=0.5)\n",
    "    return x\n",
    "\n",
    "@tf.function\n",
    "def maybe_tensor_cond(x):\n",
    "    if x < 0:\n",
    "        x = -x\n",
    "    return x\n",
    "\n",
    "test_tf_cond(hyperparam_cond, tf.ones([1], dtype=tf.float32))\n",
    "test_tf_cond(maybe_tensor_cond, tf.constant(-1))  # the condition is a tensor\n",
    "test_tf_cond(maybe_tensor_cond, -1)"
   ]
  },
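  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get one graph whose branch is chosen at runtime, pass the flag as a tensor instead of a Python bool. The sketch below uses a hypothetical function, `maybe_double`. With a `tf.bool` tensor as the condition, AutoGraph converts the `if` into `tf.cond`, so the same traced function can handle both `True` and `False`. With a Python bool, the `if` would instead be resolved during tracing, and each distinct value would trigger a separate trace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "@tf.function\n",
    "def maybe_double(x, double):\n",
    "    # `double` is a tf.bool tensor, so this if becomes tf.cond.\n",
    "    # x is defined before the if, so the implicit else branch returns it unchanged.\n",
    "    if double:\n",
    "        x = x * 2.0\n",
    "    return x\n",
    "\n",
    "print(maybe_double(tf.constant(3.0), tf.constant(True)).numpy())\n",
    "print(maybe_double(tf.constant(3.0), tf.constant(False)).numpy())"
   ]
  },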
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fT9Eeg80NrN5"
   },
   "source": [
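    "`tf.cond` has some subtleties:\n",
    "\n",
    "- It works by tracing both branches of the conditional and then choosing the appropriate branch at runtime, depending on the condition. Tracing both branches can cause unexpected execution of Python code.\n",
    "- If one branch creates a tensor that is used downstream, the other branch must also create that tensor."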
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 84
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 911,
     "status": "ok",
     "timestamp": 1560066757390,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "BIOripk2Nqbv",
    "outputId": "23be71ec-0b08-4028-f65c-e1b50da23f65"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tracing `then` branch\n",
      "Tracing `else` branch\n",
      "Executing, x: 1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=1128, shape=(), dtype=int32, numpy=1>"
      ]
     },
     "execution_count": 39,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@tf.function\n",
    "def f():\n",
    "    x = tf.constant(0)\n",
    "    if tf.constant(True):\n",
    "        x = x + 1\n",
    "        tf.print('Executing, x:', x)\n",
    "        print(\"Tracing `then` branch\")\n",
    "    else:\n",
    "        x = x - 1\n",
    "        tf.print('Executing, x:', x)  # never runs, because the condition is True\n",
    "        print(\"Tracing `else` branch\")  # traced even though it never executes\n",
    "    return x\n",
    "\n",
    "f()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GrqHlCpmLb-M"
   },
   "source": [
    "If `x` is used after the `if`, both branches must define it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 440
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 847,
     "status": "ok",
     "timestamp": 1560066896672,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "SKnZT9UyOqtS",
    "outputId": "75b09b9f-dad1-4cb3-fb78-459db7d791f7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Caught expected exception \n",
      "  <class 'ValueError'>: in converted code:\n",
      "\n",
      "    <ipython-input-40-c7af591027c1>:3 f  *\n",
      "        if tf.constant(True):\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:439 if_stmt\n",
      "        return tf_if_stmt(cond, body, orelse, get_state, set_state)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:456 tf_if_stmt\n",
      "        outputs, final_state = control_flow_ops.cond(cond, body, orelse)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py:507 new_func\n",
      "        return func(*args, **kwargs)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:1147 cond\n",
      "        return cond_v2.cond_v2(pred, true_fn, false_fn, name)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:86 cond_v2\n",
      "        op_return_value=pred)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:716 func_graph_from_py_func\n",
      "        func_outputs = python_func(*func_args, **func_kwargs)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:486 wrapper\n",
      "        outputs = func()\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:512 wrapper\n",
      "        tuple(s.symbol_name for s in undefined)))\n",
      "\n",
      "    ValueError: The following symbols must also be initialized in the else branch: ('x',). Alternatively, you may initialize them before the if statement.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def f():\n",
    "    if tf.constant(True):\n",
    "        x = tf.ones([3, 3])\n",
    "    return x\n",
    "\n",
    "# x must be defined in both branches; otherwise an exception is raised\n",
    "with assert_raises(ValueError):\n",
    "    f()"
   ]
  },
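  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The error message suggests a fix: initialize `x` before the `if`. The branch that does not assign `x` then still has a value to return. The sketch below uses a hypothetical function, `ones_or_zeros`. Both branches yield a float32 tensor, so the outputs of the generated `tf.cond` also have matching dtypes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "@tf.function\n",
    "def ones_or_zeros(flag):\n",
    "    # Initialize x before the if: the branch that does not assign x\n",
    "    # simply returns this initial value.\n",
    "    x = tf.zeros([3, 3])\n",
    "    if flag:\n",
    "        x = tf.ones([3, 3])\n",
    "    return x\n",
    "\n",
    "print(ones_or_zeros(tf.constant(True)).numpy())\n",
    "print(ones_or_zeros(tf.constant(False)).numpy())"
   ]
  },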
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aDnjKM7EPNBK"
   },
   "source": [
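    "## AutoGraph and loops\n",
    "\n",
    "AutoGraph has a few simple rules for converting loops:\n",
    "\n",
    "- `for`: converted if the iterable is a tensor\n",
    "- `while`: converted if the `while` condition depends on a tensor\n",
    "\n",
    "If a loop is converted, it is dynamically unrolled with `tf.while_loop`. In the special case of `for x in tf.data.Dataset`, it is transformed into `tf.data.Dataset.reduce` instead.\n",
    "\n",
    "If a loop is not converted, it is statically unrolled during tracing."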
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "P5X7vKnTPcDa"
   },
   "outputs": [],
   "source": [
    "# Helper: report how the loops in f were converted\n",
    "def test_dynamically_unrolled(f, *args):\n",
    "    g = f.get_concrete_function(*args).graph\n",
    "    if any(node.name == 'while' for node in g.as_graph_def().node):\n",
    "        print(\"{}({}) uses tf.while_loop.\".format(\n",
    "            f.__name__, ', '.join(map(str, args))))\n",
    "    elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):\n",
    "        print(\"{}({}) uses tf.data.Dataset.reduce.\".format(\n",
    "            f.__name__, ', '.join(map(str, args))))\n",
    "    else:\n",
    "        print(\"{}({}) gets unrolled.\".format(\n",
    "            f.__name__, ', '.join(map(str, args))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 67
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1083,
     "status": "ok",
     "timestamp": 1560067150414,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "XP8vb1SvPvJE",
    "outputId": "9e9c3366-ff40-44b2-d7ef-6f797e251633"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for_in_range() gets unrolled.\n",
      "for_in_tfrange() uses tf.while_loop.\n",
      "for_in_tfdataset() uses tf.data.Dataset.reduce.\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def for_in_range():\n",
    "    x = 0\n",
    "    for i in range(5):\n",
    "        x += i\n",
    "    return x\n",
    "\n",
    "@tf.function\n",
    "def for_in_tfrange():\n",
    "    x = tf.constant(0, dtype=tf.int32)\n",
    "    for i in tf.range(5):  # the iterable is a tensor\n",
    "        x += i\n",
    "    return x\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def for_in_tfdataset():\n",
    "    x = tf.constant(0, dtype=tf.int64)\n",
    "    for i in tf.data.Dataset.range(5):\n",
    "        x += i\n",
    "    return x\n",
    "\n",
    "test_dynamically_unrolled(for_in_range)\n",
    "test_dynamically_unrolled(for_in_tfrange)\n",
    "test_dynamically_unrolled(for_in_tfdataset)"
   ]
  },
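  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With static unrolling, the traced graph grows with the number of iterations. The sketch below makes this visible. The helper `count_add_ops` and both function names are hypothetical, and the op type names `Add`/`AddV2` are assumptions that may vary by TensorFlow version. With a Python `range`, the top-level graph should contain one addition op per iteration. With `tf.range`, the additions live inside the `tf.while_loop` body and are not repeated in the top-level graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "def count_add_ops(fn):\n",
    "    # Trace fn (which takes no arguments) and count the addition ops\n",
    "    # in its top-level graph. 'Add'/'AddV2' are assumed op type names.\n",
    "    graph = fn.get_concrete_function().graph\n",
    "    return sum(op.type in ('Add', 'AddV2') for op in graph.get_operations())\n",
    "\n",
    "@tf.function\n",
    "def sum_python_range():\n",
    "    x = tf.constant(0)\n",
    "    for i in range(5):     # Python iterable: unrolled during tracing\n",
    "        x += i\n",
    "    return x\n",
    "\n",
    "@tf.function\n",
    "def sum_tf_range():\n",
    "    x = tf.constant(0)\n",
    "    for i in tf.range(5):  # tensor iterable: converted to tf.while_loop\n",
    "        x += i\n",
    "    return x\n",
    "\n",
    "print(count_add_ops(sum_python_range), count_add_ops(sum_tf_range))"
   ]
  },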
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 50
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1310,
     "status": "ok",
     "timestamp": 1560067295790,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "oPAm3JkDQTVY",
    "outputId": "fd39137d-88e5-4bd1-e27a-1b30e1822b8e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "while_py_cond() gets unrolled.\n",
      "while_tf_cond() uses tf.while_loop.\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def while_py_cond():\n",
    "    x = 5\n",
    "    while x > 0:\n",
    "        x -= 1\n",
    "    return x\n",
    "\n",
    "@tf.function\n",
    "def while_tf_cond():\n",
    "    x = tf.constant(5)\n",
    "    while x > 0:  # x in the condition is a tensor\n",
    "        x -= 1\n",
    "    return x\n",
    "\n",
    "test_dynamically_unrolled(while_py_cond)\n",
    "test_dynamically_unrolled(while_tf_cond)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_yYDbvEoPaIs"
   },
   "source": [
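    "If a `break` or an early `return` depends on a tensor, the top-level condition or iterable should also be a tensor. Otherwise, the loop runs as a plain Python loop, and the tensor condition is used as a Python `bool`, which raises a `TypeError`."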
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 289
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1250,
     "status": "ok",
     "timestamp": 1560067489145,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "xM9fdNvRQym6",
    "outputId": "210c62b9-7f4a-4c65-8a7f-b37b6bea7946"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Caught expected exception \n",
      "  <class 'TypeError'>: in converted code:\n",
      "\n",
      "    <ipython-input-45-f42fe93cfd97>:3 buggy_while_py_true_tf_break  *\n",
      "        while True:\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:313 while_stmt\n",
      "        return _py_while_stmt(test, body, init_state, opts)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:401 _py_while_stmt\n",
      "        while test(*state):\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:698 __bool__\n",
      "        raise TypeError(\"Using a `tf.Tensor` as a Python `bool` is not allowed. \"\n",
      "\n",
      "    TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.\n",
      "\n",
      "while_tf_true_tf_break(5) uses tf.while_loop.\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def buggy_while_py_true_tf_break(x):\n",
    "    while True:\n",
    "        if tf.equal(x, 0):\n",
    "            break\n",
    "        x -= 1\n",
    "    return x\n",
    "\n",
    "@tf.function\n",
    "def while_tf_true_tf_break(x):\n",
    "    while tf.constant(True):  # tensor-dependent break, so the top-level condition must be a tensor\n",
    "        if tf.equal(x, 0):\n",
    "            break\n",
    "        x -= 1\n",
    "    return x\n",
    "\n",
    "with assert_raises(TypeError):\n",
    "    test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)\n",
    "test_dynamically_unrolled(while_tf_true_tf_break, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 289
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1223,
     "status": "ok",
     "timestamp": 1560067582158,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "0hkyGy13RU1Q",
    "outputId": "cbfe0309-f66c-472c-f5ec-1be88afbf255"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Caught expected exception \n",
      "  <class 'TypeError'>: in converted code:\n",
      "\n",
      "    <ipython-input-46-902b45f3c32e>:4 buggy_py_for_tf_break  *\n",
      "        for i in range(5):\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:110 for_stmt\n",
      "        return _py_for_stmt(iter_, extra_test, body, init_state)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:117 _py_for_stmt\n",
      "        if extra_test is not None and not extra_test(*state):\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:698 __bool__\n",
      "        raise TypeError(\"Using a `tf.Tensor` as a Python `bool` is not allowed. \"\n",
      "\n",
      "    TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.\n",
      "\n",
      "tf_for_tf_break() uses tf.while_loop.\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def buggy_py_for_tf_break():\n",
    "    x = 0\n",
    "    for i in range(5):\n",
    "        if tf.equal(i, 3):\n",
    "            break\n",
    "        x += i\n",
    "    return x\n",
    "\n",
    "@tf.function\n",
    "def tf_for_tf_break():\n",
    "    x = 0\n",
    "    for i in tf.range(5):  # tensor-dependent break, so the top-level iterable must be a tensor\n",
    "        if tf.equal(i, 3):\n",
    "            break\n",
    "        x += i\n",
    "    return x\n",
    "\n",
    "with assert_raises(TypeError):\n",
    "    test_dynamically_unrolled(buggy_py_for_tf_break)\n",
    "test_dynamically_unrolled(tf_for_tf_break)\n"
   ]
  },
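  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same rule lets a loop stop early based on tensor values, as long as the iterable itself is a tensor. The sketch below uses a hypothetical function, `first_index_above`. It returns the index of the first element greater than `threshold`, or -1 if there is none. `result` is initialized before the loop so that it is defined even if the loop never runs or never breaks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "@tf.function\n",
    "def first_index_above(values, threshold):\n",
    "    # Initialized before the loop: the loop may run zero times,\n",
    "    # and the break may never be taken.\n",
    "    result = tf.constant(-1)\n",
    "    # tf.size(values) is a tensor, so the iterable is a tensor and the\n",
    "    # tensor-dependent break below is allowed.\n",
    "    for i in tf.range(tf.size(values)):\n",
    "        if values[i] > threshold:\n",
    "            result = i\n",
    "            break\n",
    "    return result\n",
    "\n",
    "print(first_index_above(tf.constant([0.1, 0.5, 0.9, 0.7]), tf.constant(0.6)).numpy())"
   ]
  },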
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pD_vDEaGR4BW"
   },
   "source": [
    "To accumulate the results of a dynamically unrolled loop, use `tf.TensorArray`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 2167
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 897,
     "status": "ok",
     "timestamp": 1560069110088,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "B488EpdZROn5",
    "outputId": "860a02e0-a31c-4e97-b9bd-20f3e432c7df"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=1998, shape=(32, 3, 4), dtype=float32, numpy=\n",
       "array([[[0.42647886, 0.73600817, 0.10211909, 0.89989746],\n",
       "        [0.772506  , 1.6853498 , 0.48793948, 1.4499462 ],\n",
       "        [1.1096102 , 2.3388233 , 0.5920907 , 1.588302  ]],\n",
       "\n",
       "       [[0.45684695, 0.955214  , 0.9993408 , 0.8219656 ],\n",
       "        [1.2537256 , 1.2735902 , 1.645985  , 1.4104882 ],\n",
       "        [1.8525792 , 2.168501  , 1.6770781 , 2.3911076 ]],\n",
       "\n",
       "       [[0.5406686 , 0.67578137, 0.4403913 , 0.67319834],\n",
       "        [0.58906066, 1.268778  , 1.439989  , 1.4741062 ],\n",
       "        [0.9815061 , 1.4514081 , 2.3673592 , 1.6882598 ]],\n",
       "\n",
       "       [[0.855492  , 0.28206396, 0.24999726, 0.36174345],\n",
       "        [1.5141011 , 0.89166963, 0.8607675 , 0.98511755],\n",
       "        [2.4110565 , 0.9396002 , 0.9327749 , 1.5260868 ]],\n",
       "\n",
       "       [[0.81104064, 0.83997786, 0.15997863, 0.4358573 ],\n",
       "        [1.1134938 , 1.4853268 , 0.54013836, 0.9705579 ],\n",
       "        [1.5386839 , 2.316008  , 0.9928701 , 1.0856854 ]],\n",
       "\n",
       "       [[0.29561853, 0.14397347, 0.57903993, 0.03648317],\n",
       "        [0.96985126, 0.15989316, 0.9244931 , 0.85591257],\n",
       "        [1.6263086 , 0.33901858, 1.2374965 , 1.7949932 ]],\n",
       "\n",
       "       [[0.9747442 , 0.11894786, 0.7442424 , 0.97833633],\n",
       "        [1.3285936 , 0.65585303, 0.9375925 , 1.6591501 ],\n",
       "        [1.9431509 , 1.3246762 , 1.770958  , 1.8282531 ]],\n",
       "\n",
       "       [[0.77624667, 0.88325894, 0.73874867, 0.4280392 ],\n",
       "        [1.5642238 , 0.99281085, 0.9649793 , 0.6810435 ],\n",
       "        [2.3166971 , 1.3342402 , 1.2706931 , 0.8359492 ]],\n",
       "\n",
       "       [[0.5189506 , 0.17288852, 0.0951364 , 0.1809932 ],\n",
       "        [0.857113  , 0.5164219 , 0.64073265, 0.7035332 ],\n",
       "        [1.315091  , 0.7406205 , 1.2290434 , 1.1238463 ]],\n",
       "\n",
       "       [[0.6529721 , 0.29452014, 0.52032673, 0.6763296 ],\n",
       "        [1.3331723 , 1.1131297 , 0.7139894 , 0.7117865 ],\n",
       "        [2.2893934 , 1.5741713 , 1.3881162 , 1.1827962 ]],\n",
       "\n",
       "       [[0.1379869 , 0.39929485, 0.10332012, 0.36433125],\n",
       "        [0.7629658 , 1.0405501 , 1.0723822 , 0.98746634],\n",
       "        [1.307274  , 1.9727362 , 1.9607428 , 0.99152625]],\n",
       "\n",
       "       [[0.09253156, 0.1335038 , 0.6397847 , 0.64935887],\n",
       "        [0.74003506, 0.7885244 , 1.5264978 , 1.2313842 ],\n",
       "        [1.3094535 , 0.9897822 , 2.113062  , 1.7757427 ]],\n",
       "\n",
       "       [[0.43634772, 0.19575393, 0.7696773 , 0.27747822],\n",
       "        [0.91441095, 0.33673525, 1.7044361 , 0.64520323],\n",
       "        [1.4157002 , 1.108459  , 1.9478035 , 1.2194656 ]],\n",
       "\n",
       "       [[0.95848894, 0.8397597 , 0.40994358, 0.38172424],\n",
       "        [1.3659054 , 1.8318363 , 0.44894862, 0.573074  ],\n",
       "        [2.0646758 , 2.644064  , 1.0859761 , 1.0852034 ]],\n",
       "\n",
       "       [[0.06133533, 0.4990101 , 0.8753276 , 0.85797477],\n",
       "        [1.0456792 , 1.3709209 , 1.2738711 , 1.5578356 ],\n",
       "        [1.2093625 , 2.3393312 , 1.6527376 , 1.9759873 ]],\n",
       "\n",
       "       [[0.9180386 , 0.8334261 , 0.50366354, 0.6812706 ],\n",
       "        [1.4707679 , 1.6046176 , 0.7666013 , 1.0866822 ],\n",
       "        [2.4036891 , 2.3536391 , 1.5735171 , 1.908121  ]],\n",
       "\n",
       "       [[0.28029943, 0.19160116, 0.8693464 , 0.634127  ],\n",
       "        [0.36180413, 0.4781108 , 1.6698829 , 0.71363425],\n",
       "        [0.9457128 , 1.236153  , 2.3801937 , 1.6479315 ]],\n",
       "\n",
       "       [[0.9700936 , 0.05327344, 0.9682101 , 0.68595064],\n",
       "        [1.5817094 , 0.6396173 , 1.4235687 , 0.9771451 ],\n",
       "        [2.0596964 , 1.405353  , 2.3626194 , 1.8783083 ]],\n",
       "\n",
       "       [[0.93972707, 0.66909564, 0.47730482, 0.8532878 ],\n",
       "        [1.9121101 , 1.2118778 , 1.3712984 , 0.9025755 ],\n",
       "        [2.8286343 , 2.0688076 , 1.918565  , 1.3270372 ]],\n",
       "\n",
       "       [[0.58233047, 0.19796038, 0.8402959 , 0.39909327],\n",
       "        [1.1069249 , 0.27517664, 1.619991  , 1.2409564 ],\n",
       "        [1.6110749 , 0.36962962, 1.7012888 , 1.9015294 ]],\n",
       "\n",
       "       [[0.17370081, 0.04553473, 0.24088955, 0.98457193],\n",
       "        [0.5662674 , 0.8119199 , 0.6683898 , 1.7632768 ],\n",
       "        [1.233774  , 1.3301728 , 1.1902118 , 2.048545  ]],\n",
       "\n",
       "       [[0.49256516, 0.15658045, 0.9459133 , 0.15301299],\n",
       "        [0.7268286 , 0.54530656, 1.7935027 , 0.8396299 ],\n",
       "        [1.1115954 , 0.94205654, 2.678388  , 1.6442738 ]],\n",
       "\n",
       "       [[0.41741276, 0.41045308, 0.76129663, 0.55115104],\n",
       "        [1.119844  , 0.89330614, 1.6357045 , 0.89033365],\n",
       "        [1.8772408 , 1.8043301 , 2.5606012 , 1.5782887 ]],\n",
       "\n",
       "       [[0.7357129 , 0.58418286, 0.8029188 , 0.16406012],\n",
       "        [1.4910457 , 0.76808226, 1.4812832 , 0.6496997 ],\n",
       "        [2.073011  , 1.3945519 , 1.8739614 , 0.76415455]],\n",
       "\n",
       "       [[0.8534558 , 0.1850003 , 0.39985824, 0.18571115],\n",
       "        [1.3174896 , 0.87760544, 0.6591519 , 1.0132732 ],\n",
       "        [2.2406812 , 1.8455182 , 1.0867497 , 1.6781182 ]],\n",
       "\n",
       "       [[0.368649  , 0.01983261, 0.14453244, 0.1162864 ],\n",
       "        [0.4135641 , 0.6818757 , 0.25695026, 1.0381715 ],\n",
       "        [0.72543   , 1.0716959 , 0.5626936 , 1.6407071 ]],\n",
       "\n",
       "       [[0.12008417, 0.20115614, 0.93636894, 0.46159244],\n",
       "        [1.0369431 , 0.5200989 , 1.2639761 , 0.97124004],\n",
       "        [1.8990936 , 0.9364083 , 1.6754327 , 1.3866991 ]],\n",
       "\n",
       "       [[0.481802  , 0.10488498, 0.9485537 , 0.70350873],\n",
       "        [0.8788043 , 0.65142655, 1.6632828 , 1.5773771 ],\n",
       "        [1.4726118 , 1.2246094 , 1.9988285 , 2.5490787 ]],\n",
       "\n",
       "       [[0.5662435 , 0.31783056, 0.1864289 , 0.36925662],\n",
       "        [1.247858  , 0.9231796 , 0.28073692, 0.6784297 ],\n",
       "        [1.4033777 , 1.1701169 , 0.8308277 , 0.87084866]],\n",
       "\n",
       "       [[0.41028118, 0.03809142, 0.40957248, 0.7718166 ],\n",
       "        [1.238506  , 0.577742  , 0.62828207, 1.052159  ],\n",
       "        [1.936253  , 0.60738766, 0.9841944 , 1.7637898 ]],\n",
       "\n",
       "       [[0.8924757 , 0.09307694, 0.96374404, 0.3464098 ],\n",
       "        [1.8794639 , 0.18365872, 1.2509778 , 1.140793  ],\n",
       "        [2.1172588 , 0.6300105 , 2.1470428 , 2.1311703 ]],\n",
       "\n",
       "       [[0.15579033, 0.4594922 , 0.17970431, 0.19183934],\n",
       "        [0.19597077, 0.5362154 , 0.19988954, 0.38290274],\n",
       "        [0.7524748 , 1.0519221 , 0.76595306, 0.5257962 ]]], dtype=float32)>"
      ]
     },
     "execution_count": 52,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Implement a dynamic RNN\n",
    "batch_size = 32\n",
    "seq_len = 3\n",
    "feature_size = 4\n",
    "\n",
    "# RNN step: add the input to the state\n",
    "def rnn_step(inputs, state):\n",
    "    return inputs + state\n",
    "\n",
    "@tf.function\n",
    "def dynamic_rnn(rnn_step, input_data, initial_state):\n",
    "    # [batch, time, features] -> [time, batch, features],\n",
    "    # so that each time step indexes the whole batch at once\n",
    "    input_data = tf.transpose(input_data, [1, 0, 2])\n",
    "    max_seq_len = input_data.shape[0]\n",
    "\n",
    "    # States produced inside the loop must be collected in a tf.TensorArray\n",
    "    states = tf.TensorArray(tf.float32, size=max_seq_len)\n",
    "    state = initial_state\n",
    "    # Iterate over the time steps\n",
    "    for i in tf.range(max_seq_len):\n",
    "        state = rnn_step(input_data[i], state)\n",
    "        states = states.write(i, state)\n",
    "    # [time, batch, features] -> [batch, time, features]\n",
    "    return tf.transpose(states.stack(), [1, 0, 2])\n",
    "\n",
    "\n",
    "dynamic_rnn(rnn_step,\n",
    "            tf.random.uniform([batch_size, seq_len, feature_size]),\n",
    "            tf.zeros([batch_size, feature_size]))\n"
   ]
  },
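  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In `dynamic_rnn`, the number of steps is known at trace time, so a fixed `size` works. When the number of steps is known only at runtime, a `TensorArray` created with `dynamic_size=True` can grow as it is written. The sketch below uses a hypothetical function, `squares_up_to`. It collects `i * i` at each step and stacks the results after the loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "@tf.function\n",
    "def squares_up_to(n):\n",
    "    # dynamic_size=True lets the array grow past its initial size of 0.\n",
    "    squares = tf.TensorArray(tf.int32, size=0, dynamic_size=True)\n",
    "    for i in tf.range(n):\n",
    "        # write() returns a new TensorArray; reassign it so the loop carries it.\n",
    "        squares = squares.write(i, i * i)\n",
    "    return squares.stack()\n",
    "\n",
    "print(squares_up_to(tf.constant(5)).numpy())"
   ]
  },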
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FIq22kd-QjdU"
   },
   "source": [
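    "As with `tf.cond`, `tf.while_loop` comes with some subtleties:\n",
    "\n",
    "- Because a loop may execute zero times, every tensor used downstream of the `while_loop` must be initialized above the loop.\n",
    "- The shapes and dtypes of all loop variables must stay the same across iterations."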
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 289
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1137,
     "status": "ok",
     "timestamp": 1560069233709,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "OpvijyewXioF",
    "outputId": "3cb3ba34-30a2-4207-8333-bfb65e4f2d1c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Caught expected exception \n",
      "  <class 'ValueError'>: in converted code:\n",
      "\n",
      "    <ipython-input-53-05437e37672a>:3 buggy_loop_var_uninitialized  *\n",
      "        for i in tf.range(3):\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:95 for_stmt\n",
      "        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:125 _known_len_tf_for_stmt\n",
      "        _disallow_undefs_into_loop(*init_state)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:50 _disallow_undefs_into_loop\n",
      "        tuple(s.symbol_name for s in undefined)))\n",
      "\n",
      "    ValueError: TensorFlow requires that the following symbols must be defined before the loop: ('x',)\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=2062, shape=(), dtype=int32, numpy=2>"
      ]
     },
     "execution_count": 53,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@tf.function\n",
    "def buggy_loop_var_uninitialized():\n",
    "    for i in tf.range(3):\n",
    "        x = i  # x must be initialized before the loop\n",
    "    return x\n",
    "\n",
    "@tf.function\n",
    "def f():\n",
    "    x = tf.constant(0)\n",
    "    for i in tf.range(3):\n",
    "        x = i\n",
    "    return x\n",
    "\n",
    "with assert_raises(ValueError):\n",
    "    buggy_loop_var_uninitialized()\n",
    "f()"
   ]
  },
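  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This restriction applies to loops that AutoGraph converts into a `tf.while_loop`, that is, loops over tensors such as `tf.range`. A loop over a plain Python `range` stays a Python loop. It is unrolled while the function is traced, so its body always runs and `x` is defined afterwards. The cell below is a minimal sketch of this. The function name `unrolled_loop` is illustrative and does not come from the original. Unrolling adds one copy of the loop body to the graph per iteration, so it only suits small, fixed iteration counts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "@tf.function\n",
    "def unrolled_loop():\n",
    "    # Python range: AutoGraph keeps this as a Python loop, unrolled at trace time,\n",
    "    # so x is always assigned before it is returned\n",
    "    for i in range(3):\n",
    "        x = tf.constant(i)\n",
    "    return x\n",
    "\n",
    "unrolled_loop()"
   ]
  },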
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OGkKNLOAX90B"
   },
   "source": [
    "A loop variable's dtype cannot change across iterations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 70
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1083,
     "status": "ok",
     "timestamp": 1560069367284,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "z5RAJ-OpX7fr",
    "outputId": "3bc7eccc-0b42-40bc-a147-cf01336aaba0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Caught expected exception \n",
      "  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Input 1 of node while/merge/_10 was passed int32 from while/next_iteration/_28:0 incompatible with expected float. [Op:__inference_buggy_loop_type_changes_2119]\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def buggy_loop_type_changes():\n",
    "    x = tf.constant(0, dtype=tf.float32)\n",
    "    for i in tf.range(3):  # tf.range yields tf.int32 tensors, but x is tf.float32\n",
    "        x = i\n",
    "    return x\n",
    "\n",
    "with assert_raises(tf.errors.InvalidArgumentError):\n",
    "    buggy_loop_type_changes()"
   ]
  },
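  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to fix the example above is to convert the loop counter to the loop variable's dtype with `tf.cast` before assigning it, so that `x` stays `tf.float32` on every iteration. A minimal sketch follows. The function name `loop_with_cast` is illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "@tf.function\n",
    "def loop_with_cast():\n",
    "    x = tf.constant(0, dtype=tf.float32)\n",
    "    for i in tf.range(3):\n",
    "        # Cast the int32 counter to float32 so x keeps the same dtype every iteration\n",
    "        x = tf.cast(i, tf.float32)\n",
    "    return x\n",
    "\n",
    "loop_with_cast()"
   ]
  },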
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "L-mfJKvNPBGN"
   },
   "source": [
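    "Likewise, a loop variable's shape cannot change across iterations. In the second function below, `x` keeps a fixed shape of `[5, 10]`. Each iteration rebuilds it from its first `i` rows, one row of ones and `4-i` rows of zeros. `set_shape` then restores the static shape `[5, 10]`, which slicing with the tensor `i` would otherwise leave partly unknown."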
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 474
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1343,
     "status": "ok",
     "timestamp": 1560069506649,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "7q_6LL4-YwkE",
    "outputId": "a88ed062-24b3-4a69-c349-2aa7bbee361a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Caught expected exception \n",
      "  <class 'ValueError'>: in converted code:\n",
      "\n",
      "    <ipython-input-55-74d839116efa>:4 buggy_concat  *\n",
      "        for i in tf.range(5):\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:95 for_stmt\n",
      "        return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:156 _known_len_tf_for_stmt\n",
      "        opts=dict(maximum_iterations=n))\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:327 _tf_while_stmt\n",
      "        retval = control_flow_ops.while_loop(test, body, init_state, **opts)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2646 while_loop\n",
      "        return_same_structure=return_same_structure)\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:213 while_loop\n",
      "        len_orig_loop_vars], expand_composites=True))\n",
      "    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:869 _check_shapes_compat\n",
      "        \"specify a less-specific shape.\" % (input_t.name, shape, t.shape))\n",
      "\n",
      "    ValueError: Input tensor 'ones:0' enters the loop with shape (0, 10), but has shape (1, 10) after one iteration. To allow the shape to vary across iterations, use the `shape_invariants` argument of tf.while_loop to specify a less-specific shape.\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=2240, shape=(5, 10), dtype=float32, numpy=\n",
       "array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n",
       "       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>"
      ]
     },
     "execution_count": 55,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@tf.function\n",
    "def buggy_concat():\n",
    "    x = tf.ones([0, 10])\n",
    "    for i in tf.range(5):\n",
    "        x = tf.concat([x, tf.ones([1, 10])], axis=0)  # Error: x gains a row each iteration, so its shape changes\n",
    "    return x\n",
    "\n",
    "with assert_raises(ValueError):\n",
    "    buggy_concat()\n",
    "    \n",
    "@tf.function\n",
    "def concat_with_padding():\n",
    "    x = tf.zeros([5, 10])\n",
    "    for i in tf.range(5):\n",
    "        x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)\n",
    "        x.set_shape([5, 10])\n",
    "    return x\n",
    "\n",
    "concat_with_padding()"
   ]
  },
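  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Padding to a fixed shape works when the final size is known in advance. When it is not, a common alternative is to collect the rows in a `tf.TensorArray` created with `dynamic_size=True` and call `stack()` after the loop. The TensorArray is carried through the loop instead of a tensor whose shape grows, so the shape restriction does not apply. The cell below is a sketch under these assumptions. The function name `concat_with_tensor_array` and its argument `n` are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "@tf.function\n",
    "def concat_with_tensor_array(n):\n",
    "    # dynamic_size=True lets the array grow, so n need not be known when tracing\n",
    "    ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)\n",
    "    for i in tf.range(n):\n",
    "        ta = ta.write(i, tf.ones([10]))  # write() returns a new TensorArray; reassign it\n",
    "    return ta.stack()  # stacks the n rows into a tensor of shape (n, 10)\n",
    "\n",
    "concat_with_tensor_array(tf.constant(5))"
   ]
  },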
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HM0DYSlOLL_P"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "006-tf_function_and_autograph",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
